Retinal fundus image registration

Author

Charles O’Neill

Published

(Publication date unavailable)

Retinal fundus image registration

Setup

First, we need to import the appropriate modules.

Code
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2
from tqdm.notebook import tqdm
from skimage.registration import optical_flow_tvl1, optical_flow_ilk
from skimage.transform import warp
from skimage.color import rgb2gray
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import normalized_root_mse as nrmse

Next, let’s write a function to handle retrieving the images. Since we only want to register fundus images to the same eye, we specify which patient and which laterality we want to load in:

Code
def retrieve_images(patient_id = '156518', laterality = 'L'):
    """Load a patient's fundus images for one eye as grayscale arrays.

    Parameters
    ----------
    patient_id : str
        Folder name under ../data containing the patient's images.
    laterality : str
        'L' or 'R'; only files whose name ends in '<laterality>.png' load.

    Returns
    -------
    (final_images, template)
        The grayscale images whose shape matches the template, and the
        template itself (the first image in sorted filename order).

    Raises
    ------
    FileNotFoundError
        If no readable image for the requested eye is found.
    """
    # Set the root directory for the patient data
    root_dir = Path(f'../data/{patient_id}')

    # Sort filenames so the template choice is deterministic (os.listdir
    # order is filesystem-dependent); endswith avoids matching the
    # laterality suffix in the middle of a filename.
    image_filenames = sorted(f for f in os.listdir(root_dir)
                             if f.endswith(f'{laterality}.png'))

    # Read the images, skipping files cv2 could not decode
    # (cv2.imread returns None on failure, which would crash rgb2gray).
    images = [img
              for img in (cv2.imread(str(root_dir / f)) for f in image_filenames)
              if img is not None]

    if not images:
        raise FileNotFoundError(
            f'No readable *{laterality}.png images found in {root_dir}')

    # Convert the images to grayscale
    gray_images = [rgb2gray(img) for img in images]

    # Register all images to the first image
    template = gray_images[0]

    # Drop images whose resolution differs from the template's
    final_images = [x for x in gray_images[1:] if x.shape == template.shape]

    return final_images, template

When evaluating our registration algorithm, our evaluation metric will be some function that computes the distance between the registered images and the template image. We want to be able to track a few of these metrics. Some common ones include:

  • L1 loss, also known as mean absolute error, measures the average magnitude of the element-wise differences between two images. It is robust to outliers and gives equal weight to all pixels, making it a good choice for image registration.
  • RMSE, or root mean square error, is the square root of the mean of the squared differences between two images. It gives more weight to larger differences, making it sensitive to outliers. RMSE is commonly used in image registration to measure the overall difference between two images.
  • Normalised cross-correlation is a measure of the similarity between two images, taking into account their intensities. It is normalised to ensure that the result is between -1 and 1, where 1 indicates a perfect match. Normalised cross-correlation is often used in image registration to assess the quality of the registration, especially when dealing with images with different intensities.
  • Similarity is a measure of the overlap between two images, taking into account both the intensities and spatial information. Common similarity metrics used in image registration include mutual information, normalised mutual information, and the Jensen-Shannon divergence. These metrics provide a measure of the information shared between two images, making them well suited for assessing the quality of image registration.

The following function takes a list of registered images, as well as the template image, and calculates the above metrics for each image:

Code
def evaluate_registration(template_img, registered_imgs):
    """Score each registered image against the template.

    Returns three parallel lists (l1_losses, ncc_values, ssim_values),
    one entry per image in `registered_imgs`.
    """
    def _scores(reg):
        # mean absolute error between template and registered image
        l1 = np.mean(np.abs(template_img - reg))
        # Pearson correlation of the flattened intensities
        ncc = np.corrcoef(template_img.ravel(), reg.ravel())[0, 1]
        # structural similarity over the registered image's intensity range
        s = ssim(template_img, reg, data_range=reg.max() - reg.min())
        return l1, ncc, s

    l1_losses, ncc_values, ssim_values = [], [], []
    for reg in registered_imgs:
        l1, ncc, s = _scores(reg)
        l1_losses.append(l1)
        ncc_values.append(ncc)
        ssim_values.append(s)

    return l1_losses, ncc_values, ssim_values

Given these losses, it’s probably a good idea to have some sort of function that shows us the best and worst registered images, based on the loss. This is somewhat similar to viewing individual examples from a confusion matrix in a classification task.

Code
def visualise_registration_results(registered_images, original_images, template, loss_values):
    """Show the three worst and three best registrations side by side.

    Builds a 3x4 grid: columns 0-1 hold (original, registered) pairs for
    the three highest-loss images, columns 2-3 the same for the three
    lowest-loss images.  Registered panels are titled with their L1 loss.
    """
    order = np.argsort(loss_values)

    fig, axes = plt.subplots(3, 4, figsize=(20, 15))
    fig.subplots_adjust(hspace=0.4, wspace=0.4)

    # (first column of the pair, image indices) for worst then best trios
    for col0, indices in ((0, order[-3:]), (2, order[:3])):
        for row, idx in enumerate(indices):
            axes[row][col0].imshow(original_images[idx], cmap='gray')
            axes[row][col0].set_title("Original Image")

            axes[row][col0 + 1].imshow(registered_images[idx], cmap='gray')
            axes[row][col0 + 1].set_title(
                "Registered Image (L1 Loss: {:.2f})".format(loss_values[idx]))

    # Show the grid
    plt.show()

Exploratory data analysis

Code
# Define the path to the directory containing the fundus images
fundus_dir = Path("../data/156518")

# Keep only PNG files belonging to the left eye
fundus_images = [
    p for p in fundus_dir.iterdir()
    if p.is_file() and p.suffix == ".png" and "_L.png" in p.name
]

# Load the first fundus image with Matplotlib's imread
first_image = fundus_images[0]
image = plt.imread(first_image)

# Display it
plt.imshow(image, cmap='gray')
plt.show()

Code
# Define the number of rows and columns for the subplots
nrows, ncols = 3, 3

# Create the subplot grid and flatten it for easy indexing
fig, axs = plt.subplots(nrows, ncols, figsize=(15, 15))
axs = axs.ravel()

# Load and display the first 9 images, one per axis
for i, ax in enumerate(axs[:9]):
    image = plt.imread(fundus_images[i])
    ax.imshow(image, cmap='gray')
    ax.axis("off")

plt.tight_layout()
plt.show()

Code
# Set the root directory for the patient data
root_dir = Path('../data/156518')

# Get the list of image filenames for the left eye
image_filenames = [f for f in os.listdir(root_dir) if 'L.png' in f]

# Read the images into a list
images = [cv2.imread(str(root_dir / f)) for f in image_filenames]

# Convert the images to grayscale
gray_images = [rgb2gray(img) for img in images]

# Register all images to the first image
template = gray_images[0]

# Remove images whose shape differs from the template's
final_images = [x for x in gray_images[1:] if x.shape == template.shape]

# Register every image to the template with TV-L1 optical flow.
# The coordinate grid depends only on the template, so build it once
# outside the loop instead of per image.
row_coords, col_coords = np.meshgrid(np.arange(template.shape[0]),
                                     np.arange(template.shape[1]),
                                     indexing='ij')
registered_images = []
for img in tqdm(final_images):
    # estimated per-pixel displacement field (v: rows, u: columns)
    v, u = optical_flow_tvl1(template, img)
    # resample the moving image at the flow-displaced coordinates
    registered = warp(img, np.array([row_coords + v, col_coords + u]), mode='edge')
    registered_images.append(registered)
Code
# Compare one example side by side: input image, its registration, template
fig, axs = plt.subplots(1, 3, figsize=(20, 20))
axs = axs.ravel()

i = 2
panels = (
    (final_images[i], "Previous image"),
    (registered_images[i], "Registered image"),
    (template, "Template"),
)
for ax, (img, label) in zip(axs, panels):
    ax.imshow(img, cmap='gray')
    ax.set_xlabel(label)

plt.show()

Code
def evaluate_registration(template_img, registered_imgs):
    """Compute L1 loss, normalised cross-correlation and SSIM for every
    registered image relative to the template.

    Returns three parallel lists, one value per registered image.
    """
    # mean absolute error per image
    l1_losses = [np.mean(np.abs(template_img - reg))
                 for reg in registered_imgs]
    # Pearson correlation of the flattened intensities
    ncc_values = [np.corrcoef(template_img.ravel(), reg.ravel())[0, 1]
                  for reg in registered_imgs]
    # structural similarity, scaled to each registered image's range
    ssim_values = [ssim(template_img, reg, data_range=reg.max() - reg.min())
                   for reg in registered_imgs]

    return l1_losses, ncc_values, ssim_values


# Score every registered image against the template
l1_losses, ncc_values, ssim_values = evaluate_registration(template, registered_images)

# Report the mean of each metric over all registered images
print("L1 losses:", np.mean(l1_losses))
print("Normalized cross-correlation values:", np.mean(ncc_values))
print("Structural similarity index values:", np.mean(ssim_values))
L1 losses: 0.08169580604522776
Normalized cross-correlation values: 0.7297043703588232
Structural similarity index values: 0.6343061030028727
Code
def visualise_registration_results(registered_images, original_images, template, loss_values):
    """Plot the three highest-loss and three lowest-loss registrations.

    Grid layout is 3x4: columns 0-1 show (original, registered) pairs for
    the worst three images, columns 2-3 the same for the best three, with
    each registered panel titled with its L1 loss.
    """
    ranked = np.argsort(loss_values)
    worst_three = ranked[-3:]
    best_three = ranked[:3]

    fig, axes = plt.subplots(3, 4, figsize=(20, 15))
    fig.subplots_adjust(hspace=0.4, wspace=0.4)

    def _show_pair(row, first_col, idx):
        # original on the left of the pair, registration on the right
        axes[row][first_col].imshow(original_images[idx], cmap='gray')
        axes[row][first_col].set_title("Original Image")
        axes[row][first_col + 1].imshow(registered_images[idx], cmap='gray')
        axes[row][first_col + 1].set_title(
            "Registered Image (L1 Loss: {:.2f})".format(loss_values[idx]))

    for row, idx in enumerate(worst_three):
        _show_pair(row, 0, idx)
    for row, idx in enumerate(best_three):
        _show_pair(row, 2, idx)

    # Show the grid
    plt.show()
Code
# NOTE: the original cell read `opt.registered_images`, but `opt` (the
# RegistrationAlgorithm instance) is not defined until later in the
# document; use the `registered_images` list computed by the optical-flow
# loop above instead.
original_images = images

visualise_registration_results(registered_images, original_images, template, l1_losses)

General Python class

Code
def retrieve_images(patient_id = '156518', laterality = 'L'):
    """Load a patient's fundus images for one eye as grayscale arrays.

    Generalised with a backward-compatible `laterality` parameter
    (default 'L' preserves the previous behaviour), matching the earlier
    version of this helper.

    Returns (final_images, template): the grayscale images whose shape
    matches the template, and the template itself (the first image).
    """
    # Set the root directory for the patient data
    root_dir = Path(f'../data/{patient_id}')

    # Get the list of image filenames for the requested eye
    image_filenames = [f for f in os.listdir(root_dir) if f'{laterality}.png' in f]

    # Read the images into a list
    images = [cv2.imread(str(root_dir / f)) for f in image_filenames]

    # Convert the images to grayscale
    gray_images = [rgb2gray(img) for img in images]

    # Register all images to the first image
    template = gray_images[0]

    # Remove images whose shape differs from the template's
    final_images = [x for x in gray_images[1:] if x.shape == template.shape]

    return final_images, template
Code
def optical_flow(template, img):
    """Warp `img` towards `template` using a TV-L1 optical-flow field."""
    # per-pixel displacement: flow_rows along axis 0, flow_cols along axis 1
    flow_rows, flow_cols = optical_flow_tvl1(template, img)
    height, width = template.shape
    # template coordinate grid shifted by the estimated flow
    grid = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
    coords = np.array([grid[0] + flow_rows, grid[1] + flow_cols])
    return warp(img, coords, mode='edge')
Code
import wandb

class RegistrationAlgorithm:
    """Apply a registration function to a patient's fundus images and
    score the results against the template image.

    `registration_function(template, img)` must return the registered
    version of `img`.
    """

    def __init__(self, registration_function):
        self.registration_function = registration_function
        # images to register and the image they are registered to
        self.final_images, self.template = retrieve_images()
        self.registered_images = self.apply_registration()

    def apply_registration(self):
        """Register every image to the template and return the list."""
        registered_images = []
        # enumerate index was unused in the original loop, so iterate directly
        for img in tqdm(self.final_images):
            registered_images.append(self.registration_function(self.template, img))
        return registered_images

    def evaluate_registration(self):
        """Return per-image (l1_losses, ncc_values, ssim_values) lists
        comparing each registered image to the template."""
        l1_losses = []
        ncc_values = []
        ssim_values = []

        for registered_img in self.registered_images:
            # mean absolute error
            l1_losses.append(np.mean(np.abs(self.template - registered_img)))
            # Pearson correlation of the flattened intensities
            ncc_values.append(
                np.corrcoef(self.template.ravel(), registered_img.ravel())[0, 1])
            # structural similarity over the registered image's range
            ssim_values.append(
                ssim(self.template, registered_img,
                     data_range=registered_img.max() - registered_img.min()))

        return l1_losses, ncc_values, ssim_values
Code
opt = RegistrationAlgorithm(optical_flow)
Code
# Summarise registration quality as the mean of each metric
l1_losses, ncc_values, ssim_values = opt.evaluate_registration()
print("L1 losses:", f"{np.mean(l1_losses):.2f}")
print("Normalized cross-correlation values:", f"{np.mean(ncc_values):.2f}")
print("Structural similarity index values:", f"{np.mean(ssim_values):.2f}")
L1 losses: 0.08
Normalized cross-correlation values: 0.73
Structural similarity index values: 0.63

Optical flow

Code
def optical_flow(template, img):
    """Register `img` onto `template` using TV-L1 optical flow."""
    # estimate the displacement field (row component v, column component u)
    v, u = optical_flow_tvl1(template, img)
    rows, cols = template.shape
    # coordinate grid of the template, shifted by the estimated flow
    rr, cc = np.meshgrid(np.arange(rows), np.arange(cols), indexing='ij')
    warped = warp(img, np.array([rr + v, cc + u]), mode='edge')
    return warped
Code
# Re-run the optical-flow registration and report the mean metrics
opt = RegistrationAlgorithm(optical_flow)
l1_losses, ncc_values, ssim_values = opt.evaluate_registration()
print("L1 losses:", f"{np.mean(l1_losses):.2f}")
print("Normalized cross-correlation values:", f"{np.mean(ncc_values):.2f}")
print("Structural similarity index values:", f"{np.mean(ssim_values):.2f}")
L1 losses: 0.08
Normalized cross-correlation values: 0.73
Structural similarity index values: 0.63
Code
visualise_registration_results(opt.registered_images, images, template, l1_losses)

SimpleElastix

The main idea is to solve a pairwise optimisation problem by minimising the cost function C. The optimisation can be formulated as \hat{T} = \text{argmin}_T C(T, I_f, I_m) with cost function defined as C(T, I_f, I_m) = -S(T, I_f, I_m) + \gamma P(T) where T is the transformation matrix, S is the similarity measurement and P is the penalty term with regulariser parameter \gamma.

SimpleElastix is based on the parametric approach to solving the optimisation problem, where the number of possible transformations is limited by introducing a parametrisation (model) of the transform. The optimisation becomes \hat{T}_\mu = \text{argmin}_{T_\mu} C(T_\mu, I_f, I_m), where T_\mu denotes the parametrised transform and the vector \mu contains the values of the transformation parameters. For a 2D rigid transformation, the parameter vector \mu contains one rotation angle and the translations in the x and y directions.

Code
import SimpleITK as sitk

def simple_elastix(template, image):
    """Register `image` (moving) to `template` (fixed) with SimpleITK's
    Demons deformable registration.

    Two fixes versus the original version:
    1. Parameter order now matches the caller — RegistrationAlgorithm
       invokes `registration_function(self.template, img)`, so the first
       argument is the fixed/template image.  The original signature had
       them swapped, exchanging fixed and moving images.
    2. DemonsRegistrationFilter.Execute returns a displacement FIELD, not
       a resampled image.  The original averaged the field's components
       and returned that as the "registered image"; here the field is
       applied to the moving image via a DisplacementFieldTransform,
       which explains the previously nonsensical evaluation metrics.
    """
    # Convert the input arrays to SimpleITK images (Demons needs float pixels)
    fixed_image = sitk.GetImageFromArray(template.astype(np.float32))
    moving_image = sitk.GetImageFromArray(image.astype(np.float32))

    # Create and configure the Demons registration filter
    registration_method = sitk.DemonsRegistrationFilter()
    registration_method.SetNumberOfIterations(100)
    registration_method.SetStandardDeviations(0.01)

    # Estimate the displacement field aligning the moving image to the fixed one
    displacement_field = registration_method.Execute(fixed_image, moving_image)

    # Apply the field to the moving image to obtain the registered image
    # (DisplacementFieldTransform requires a vector-of-float64 field)
    transform = sitk.DisplacementFieldTransform(
        sitk.Cast(displacement_field, sitk.sitkVectorFloat64))
    registered = sitk.Resample(moving_image, fixed_image, transform,
                               sitk.sitkLinear, 0.0)

    # Convert the result back to a numpy array for the evaluation code
    return sitk.GetArrayFromImage(registered)

And now we preprocess the images as before, getting the template and the images to be registered:

Code
# retrieve images to be registered, and the image to register to
images, template = retrieve_images()

# perform the registration using SimpleElastix and report the mean metrics
opt = RegistrationAlgorithm(simple_elastix)
l1_losses, ncc_values, ssim_values = opt.evaluate_registration()
print("L1 losses:", f"{np.mean(l1_losses):.2f}")
print("Normalized cross-correlation values:", f"{np.mean(ncc_values):.2f}")
print("Structural similarity index values:", f"{np.mean(ssim_values):.2f}")
L1 losses: 1.34
Normalized cross-correlation values: -0.06
Structural similarity index values: 0.07

ORB

Code
import numpy as np
import cv2

def colour_to_greyscale(image):
    """Convert an RGB image (channels-last) to 8-bit greyscale using the
    ITU-R BT.601 luma weights (0.299 R + 0.587 G + 0.114 B)."""
    red, green, blue = image[..., 0], image[..., 1], image[..., 2]
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    # truncate (not round) to match uint8 conversion semantics
    return luma.astype(np.uint8)

2D Voxelmorph with Spatial Transformer Network

Code
import torch

def voxelmorph_loss_2d(source, target, source_weight=1, target_weight=1, smoothness_weight=0.001):
    """L1 reconstruction loss between `source` and `target` plus a
    smoothness penalty on `target`'s spatial gradients (NCHW tensors).

    NOTE(review): `source_weight` is accepted but never used; it is kept
    for interface compatibility — confirm the intended weighting with the
    author.
    """
    def _first_differences(field):
        # forward differences over the two trailing (spatial) dimensions
        base = field[:, :, :-1, :-1]
        diff_x = base - field[:, :, 1:, :-1]
        diff_y = base - field[:, :, :-1, 1:]
        return diff_x, diff_y

    diff_x, diff_y = _first_differences(target)
    smoothness_penalty = (diff_x.abs().mean() + diff_y.abs().mean()) * smoothness_weight
    reconstruction_loss = (source - target).abs().mean() * target_weight
    return reconstruction_loss + smoothness_penalty